In [1]:
import sys
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from pathlib import Path
import matplotlib.pyplot as plt

import importlib
if '/opt/notebooks/' not in sys.path:
    sys.path.append('/opt/notebooks/')

try: importlib.reload(sl)
except: import synt_lib as sl
In [2]:
DIRS = sl.get_dirs()
M_PARAMS = sl.get_model_params()

Get data for testing

wav_fnames = Path(DIRS['RAW_DATA']).rglob("*.wav") fname = wav_fnames.__next__().as_posix()X = sl.load_audio_one_hot(fname)sess = tf.Session()X.eval(session=sess)sl.write_audio_one_hot('sample.wav', X, sess)X = tf.concat(X, axis=0)[0]

A very basic network

Set model

In [3]:
num_epochs = 20
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length
In [3]:
def generateData():
    x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0

    x = x.reshape((batch_size, -1))  # The first index changing slowest, subseries as rows
    y = y.reshape((batch_size, -1))

    return (x, y)

Variables and placeholders

In [5]:
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

init_state = tf.placeholder(tf.float32, [batch_size, state_size])
In [6]:
W = tf.Variable(np.random.rand(state_size+1, state_size), dtype=tf.float32)
b = tf.Variable(np.zeros((1,state_size)), dtype=tf.float32)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)

Unpacking (or unstacking)

In [7]:
inputs_series = tf.unstack(batchX_placeholder, axis=1)
labels_series = tf.unstack(batchY_placeholder, axis=1)
In [8]:
inputs_series
Out[8]:
[<tf.Tensor 'unstack:0' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:1' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:2' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:3' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:4' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:5' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:6' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:7' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:8' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:9' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:10' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:11' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:12' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:13' shape=(5,) dtype=float32>,
 <tf.Tensor 'unstack:14' shape=(5,) dtype=float32>]

Forward pass

In [9]:
# Forward pass
current_state = init_state
states_series = []
for current_input in inputs_series:
    current_input = tf.reshape(current_input, [batch_size, 1])
    input_and_state_concatenated = tf.concat([current_input, current_state], axis=1)  # Increasing number of columns

    next_state = tf.tanh(tf.matmul(input_and_state_concatenated, W) + b)  # Broadcasted addition
    states_series.append(next_state)
    current_state = next_state

Calculating loss

In [10]:
logits_series = [tf.matmul(state, W2) + b2 for state in states_series] #Broadcasted addition
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) for logits, labels in zip(logits_series,labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)

Visualizing the training

In [4]:
def plot(loss_list, predictions_series, batchX, batchY):
    plt.subplot(2, 3, 1)
    plt.cla()
    plt.plot(loss_list)

    for batch_series_idx in range(5):
        one_hot_output_series = np.array(predictions_series)[:, batch_series_idx, :]
        single_output_series = np.array([(1 if out[0] < 0.5 else 0) for out in one_hot_output_series])

        plt.subplot(2, 3, batch_series_idx + 2)
        plt.cla()
        plt.axis([0, truncated_backprop_length, 0, 2])
        left_offset = range(truncated_backprop_length)
        plt.bar(left_offset, batchX[batch_series_idx, :], width=1, color="blue")
        plt.bar(left_offset, batchY[batch_series_idx, :] * 0.5, width=1, color="red")
        plt.bar(left_offset, single_output_series * 0.3, width=1, color="green")

    plt.draw()
    plt.pause(0.0001)

Running a training session

In [12]:
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()
        _current_state = np.zeros((batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })

            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()
WARNING:tensorflow:From /root/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py:189: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
<Figure size 432x288 with 0 Axes>
New data, epoch 0
Step 0 Loss 0.8556402
Step 100 Loss 0.6905209
Step 200 Loss 0.6992068
Step 300 Loss 0.70954394
Step 400 Loss 0.6922239
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-12-1e06b31df1e9> in <module>
     31             if batch_idx%100 == 0:
     32                 print("Step",batch_idx, "Loss", _total_loss)
---> 33                 plot(loss_list, _predictions_series, batchX, batchY)
     34 
     35 plt.ioff()

<ipython-input-11-a7a221093bf1> in plot(loss_list, predictions_series, batchX, batchY)
     17 
     18     plt.draw()
---> 19     plt.pause(0.0001)

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in pause(interval)
    292         if canvas.figure.stale:
    293             canvas.draw_idle()
--> 294         show(block=False)
    295         canvas.start_event_loop(interval)
    296     else:

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in show(*args, **kw)
    252     """
    253     global _show
--> 254     return _show(*args, **kw)
    255 
    256 

~/anaconda3/lib/python3.6/site-packages/ipykernel/pylab/backend_inline.py in show(close, block)
     37             display(
     38                 figure_manager.canvas.figure,
---> 39                 metadata=_fetch_figure_metadata(figure_manager.canvas.figure)
     40             )
     41     finally:

~/anaconda3/lib/python3.6/site-packages/IPython/core/display.py in display(include, exclude, metadata, transient, display_id, *objs, **kwargs)
    302             publish_display_data(data=obj, metadata=metadata, **kwargs)
    303         else:
--> 304             format_dict, md_dict = format(obj, include=include, exclude=exclude)
    305             if not format_dict:
    306                 # nothing to display (e.g. _ipython_display_ took over)

~/anaconda3/lib/python3.6/site-packages/IPython/core/formatters.py in format(self, obj, include, exclude)
    178             md = None
    179             try:
--> 180                 data = formatter(obj)
    181             except:
    182                 # FIXME: log the exception

<decorator-gen-9> in __call__(self, obj)

~/anaconda3/lib/python3.6/site-packages/IPython/core/formatters.py in catch_format_error(method, self, *args, **kwargs)
    222     """show traceback on failed format call"""
    223     try:
--> 224         r = method(self, *args, **kwargs)
    225     except NotImplementedError:
    226         # don't warn on NotImplementedErrors

~/anaconda3/lib/python3.6/site-packages/IPython/core/formatters.py in __call__(self, obj)
    339                 pass
    340             else:
--> 341                 return printer(obj)
    342             # Finally look for special method names
    343             method = get_real_method(obj, self.print_method)

~/anaconda3/lib/python3.6/site-packages/IPython/core/pylabtools.py in <lambda>(fig)
    242 
    243     if 'png' in formats:
--> 244         png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png', **kwargs))
    245     if 'retina' in formats or 'png2x' in formats:
    246         png_formatter.for_type(Figure, lambda fig: retina_figure(fig, **kwargs))

~/anaconda3/lib/python3.6/site-packages/IPython/core/pylabtools.py in print_figure(fig, fmt, bbox_inches, **kwargs)
    126 
    127     bytes_io = BytesIO()
--> 128     fig.canvas.print_figure(bytes_io, **kw)
    129     data = bytes_io.getvalue()
    130     if fmt == 'svg':

~/anaconda3/lib/python3.6/site-packages/matplotlib/backend_bases.py in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, **kwargs)
   2051                     bbox_artists = kwargs.pop("bbox_extra_artists", None)
   2052                     bbox_inches = self.figure.get_tightbbox(renderer,
-> 2053                             bbox_extra_artists=bbox_artists)
   2054                     pad = kwargs.pop("pad_inches", None)
   2055                     if pad is None:

~/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py in get_tightbbox(self, renderer, bbox_extra_artists)
   2268 
   2269         for a in artists:
-> 2270             bbox = a.get_tightbbox(renderer)
   2271             if bbox is not None and (bbox.width != 0 or bbox.height != 0):
   2272                 bb.append(bbox)

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in get_tightbbox(self, renderer, call_axes_locator, bbox_extra_artists)
   4394 
   4395         for a in bbox_artists:
-> 4396             bbox = a.get_tightbbox(renderer)
   4397             if (bbox is not None and
   4398                     (bbox.width != 0 or bbox.height != 0) and

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in get_tightbbox(self, renderer)
   1138             return
   1139 
-> 1140         ticks_to_draw = self._update_ticks(renderer)
   1141 
   1142         self._update_label_position(renderer)

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in _update_ticks(self, renderer)
   1021 
   1022         interval = self.get_view_interval()
-> 1023         tick_tups = list(self.iter_ticks())  # iter_ticks calls the locator
   1024         if self._smart_bounds and tick_tups:
   1025             # handle inverted limits

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in iter_ticks(self)
    965         Iterate through all of the major and minor ticks.
    966         """
--> 967         majorLocs = self.major.locator()
    968         majorTicks = self.get_major_ticks(len(majorLocs))
    969         self.major.formatter.set_locs(majorLocs)

~/anaconda3/lib/python3.6/site-packages/matplotlib/ticker.py in __call__(self)
   1983     def __call__(self):
   1984         vmin, vmax = self.axis.get_view_interval()
-> 1985         return self.tick_values(vmin, vmax)
   1986 
   1987     def tick_values(self, vmin, vmax):

~/anaconda3/lib/python3.6/site-packages/matplotlib/ticker.py in tick_values(self, vmin, vmax)
   1991         vmin, vmax = mtransforms.nonsingular(
   1992             vmin, vmax, expander=1e-13, tiny=1e-14)
-> 1993         locs = self._raw_ticks(vmin, vmax)
   1994 
   1995         prune = self._prune

~/anaconda3/lib/python3.6/site-packages/matplotlib/ticker.py in _raw_ticks(self, vmin, vmax)
   1930         if self._nbins == 'auto':
   1931             if self.axis is not None:
-> 1932                 nbins = np.clip(self.axis.get_tick_space(),
   1933                                 max(1, self._min_n_ticks - 1), 9)
   1934             else:

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in get_tick_space(self)
   2537         ends = self.axes.transAxes.transform([[0, 0], [0, 1]])
   2538         length = ((ends[1][1] - ends[0][1]) / self.axes.figure.dpi) * 72
-> 2539         tick = self._get_tick(True)
   2540         # Having a spacing of at least 2 just looks good.
   2541         size = tick.label1.get_size() * 2.0

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in _get_tick(self, major)
   2192         else:
   2193             tick_kw = self._minor_tick_kw
-> 2194         return YTick(self.axes, 0, '', major=major, **tick_kw)
   2195 
   2196     def _get_label(self):

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in __init__(self, axes, loc, label, size, width, color, tickdir, pad, labelsize, labelcolor, zorder, gridOn, tick1On, tick2On, label1On, label2On, major, labelrotation, grid_color, grid_linestyle, grid_linewidth, grid_alpha, **kw)
    178         self.label1 = self._get_text1()
    179         self.label = self.label1  # legacy name
--> 180         self.label2 = self._get_text2()
    181 
    182         self.gridOn = gridOn

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in _get_text2(self)
    571         'Get the default Text instance'
    572         # x in axes coords, y in data coords
--> 573         trans, vert, horiz = self._get_text2_transform()
    574         t = mtext.Text(
    575             x=1, y=0,

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in _get_text2_transform(self)
    536 
    537     def _get_text2_transform(self):
--> 538         return self.axes.get_yaxis_text2_transform(self._pad)
    539 
    540     def apply_tickdir(self, tickdir):

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in get_yaxis_text2_transform(self, pad_points)
    812         return (self.get_yaxis_transform(which='tick2') +
    813                 mtransforms.ScaledTranslation(pad_points / 72, 0,
--> 814                                               self.figure.dpi_scale_trans),
    815                 labels_align, "left")
    816 

~/anaconda3/lib/python3.6/site-packages/matplotlib/transforms.py in __add__(self, other)
   1271         """
   1272         if isinstance(other, Transform):
-> 1273             return composite_transform_factory(self, other)
   1274         raise TypeError(
   1275             "Can not add Transform to object of type '%s'" % type(other))

~/anaconda3/lib/python3.6/site-packages/matplotlib/transforms.py in composite_transform_factory(a, b)
   2558     elif isinstance(a, Affine2D) and isinstance(b, Affine2D):
   2559         return CompositeAffine2D(a, b)
-> 2560     return CompositeGenericTransform(a, b)
   2561 
   2562 

~/anaconda3/lib/python3.6/site-packages/matplotlib/transforms.py in __init__(self, a, b, **kwargs)
   2371         self._a = a
   2372         self._b = b
-> 2373         self.set_children(a, b)
   2374 
   2375     is_affine = property(lambda self: self._a.is_affine and self._b.is_affine)

KeyboardInterrupt: 
In [13]:
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

init_state = tf.placeholder(tf.float32, [batch_size, state_size])

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)
In [15]:
# Unpack columns
inputs_series = tf.split(axis=1, num_or_size_splits=truncated_backprop_length, value=batchX_placeholder)
labels_series = tf.unstack(batchY_placeholder, axis=1)

# Forward passes
cell = tf.nn.rnn_cell.BasicRNNCell(state_size)
states_series, current_state = tf.nn.static_rnn(cell, inputs_series, init_state)
In [16]:
logits_series = [tf.matmul(state, W2) + b2 for state in states_series] #Broadcasted addition
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) for logits, labels in zip(logits_series,labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)
In [17]:
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()
        _current_state = np.zeros((batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })

            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()
<Figure size 432x288 with 0 Axes>
New data, epoch 0
Step 0 Loss 0.6982855
Step 100 Loss 0.013246939
Step 200 Loss 0.008279111
Step 300 Loss 0.0034963607
Step 400 Loss 0.00208223
Step 500 Loss 0.0017952062
Step 600 Loss 0.001557593
New data, epoch 1
Step 0 Loss 0.15199867
Step 100 Loss 0.0011497578
Step 200 Loss 0.0012072121
Step 300 Loss 0.0007493489
Step 400 Loss 0.000710824
Step 500 Loss 0.0007313626
Step 600 Loss 0.0006299488
New data, epoch 2
Step 0 Loss 0.26751977
Step 100 Loss 0.00074356975
Step 200 Loss 0.00065260654
Step 300 Loss 0.0006024815
Step 400 Loss 0.0005332958
Step 500 Loss 0.0005430037
Step 600 Loss 0.0005264984
New data, epoch 3
Step 0 Loss 0.116931506
Step 100 Loss 0.00038719067
Step 200 Loss 0.00039393836
Step 300 Loss 0.00032965202
Step 400 Loss 0.0003003399
Step 500 Loss 0.00034821747
Step 600 Loss 0.0003380571
New data, epoch 4
Step 0 Loss 0.21497773
Step 100 Loss 0.00032263552
Step 200 Loss 0.00029311565
Step 300 Loss 0.00031140816
Step 400 Loss 0.00057091325
Step 500 Loss 0.0002645592
Step 600 Loss 0.00026720585
New data, epoch 5
Step 0 Loss 0.15341614
Step 100 Loss 0.00039618427
Step 200 Loss 0.00029840012
Step 300 Loss 0.0002890405
Step 400 Loss 0.000247505
Step 500 Loss 0.00020743022
Step 600 Loss 0.0002527078
New data, epoch 6
Step 0 Loss 0.25680906
Step 100 Loss 0.00031917397
Step 200 Loss 0.0003200168
Step 300 Loss 0.0002946855
Step 400 Loss 0.0002510799
Step 500 Loss 0.00024263379
Step 600 Loss 0.00025719928
New data, epoch 7
Step 0 Loss 0.4298251
Step 100 Loss 0.00027233074
Step 200 Loss 0.00024521316
Step 300 Loss 0.00024086528
Step 400 Loss 0.00023651737
Step 500 Loss 0.00027099554
Step 600 Loss 0.00018952061
New data, epoch 8
Step 0 Loss 0.24153669
Step 100 Loss 0.00018646351
Step 200 Loss 0.00014600725
Step 300 Loss 0.00017605706
Step 400 Loss 0.00017437991
Step 500 Loss 0.00012851991
Step 600 Loss 0.00015479886
New data, epoch 9
Step 0 Loss 0.15876548
Step 100 Loss 0.00015596561
Step 200 Loss 0.00013483621
Step 300 Loss 0.00014950764
Step 400 Loss 0.0001469723
Step 500 Loss 0.00013401255
Step 600 Loss 0.00013240283
New data, epoch 10
Step 0 Loss 0.23534058
Step 100 Loss 0.00011632482
Step 200 Loss 0.00013350137
Step 300 Loss 0.00013189923
Step 400 Loss 0.0001457403
Step 500 Loss 0.00010924783
Step 600 Loss 0.00012994547
New data, epoch 11
Step 0 Loss 0.32693768
Step 100 Loss 0.00013656495
Step 200 Loss 0.00013764921
Step 300 Loss 0.00015430438
Step 400 Loss 0.00013483367
Step 500 Loss 0.0001254723
Step 600 Loss 0.000131059
New data, epoch 12
Step 0 Loss 0.56939805
Step 100 Loss 0.00014214417
Step 200 Loss 0.00010522907
Step 300 Loss 0.00014844668
Step 400 Loss 0.00013075993
Step 500 Loss 0.000115495735
Step 600 Loss 0.00012718976
New data, epoch 13
Step 0 Loss 0.2304152
Step 100 Loss 0.0001468133
Step 200 Loss 0.00013610053
Step 300 Loss 0.00014839655
Step 400 Loss 0.00011869021
Step 500 Loss 0.00013058132
Step 600 Loss 0.000116928444
New data, epoch 14
Step 0 Loss 0.43813437
Step 100 Loss 0.00015788985
Step 200 Loss 0.00015327625
Step 300 Loss 0.00012214221
Step 400 Loss 0.00010949075
Step 500 Loss 0.000115712384
Step 600 Loss 9.055587e-05
New data, epoch 15
Step 0 Loss 0.38702154
Step 100 Loss 8.97052e-05
Step 200 Loss 8.702542e-05
Step 300 Loss 0.000103681195
Step 400 Loss 9.1620415e-05
Step 500 Loss 6.9231304e-05
Step 600 Loss 9.178152e-05
New data, epoch 16
Step 0 Loss 0.18606181
Step 100 Loss 9.198912e-05
Step 200 Loss 9.413579e-05
Step 300 Loss 8.587715e-05
Step 400 Loss 9.597168e-05
Step 500 Loss 0.00010871191
Step 600 Loss 8.89896e-05
New data, epoch 17
Step 0 Loss 0.16330035
Step 100 Loss 0.00010738857
Step 200 Loss 9.1690534e-05
Step 300 Loss 8.744978e-05
Step 400 Loss 9.423361e-05
Step 500 Loss 9.08644e-05
Step 600 Loss 7.7074765e-05
New data, epoch 18
Step 0 Loss 0.15794377
Step 100 Loss 8.988769e-05
Step 200 Loss 0.000105293046
Step 300 Loss 9.0536734e-05
Step 400 Loss 8.9812114e-05
Step 500 Loss 7.1184775e-05
Step 600 Loss 9.3409464e-05
New data, epoch 19
Step 0 Loss 0.32232207
Step 100 Loss 9.126863e-05
Step 200 Loss 9.5913456e-05
Step 300 Loss 8.7263375e-05
Step 400 Loss 7.51517e-05
Step 500 Loss 8.377896e-05
Step 600 Loss 7.816603e-05
In [18]:
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])
In [19]:
cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)
In [20]:
W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)
In [21]:
# Unpack columns
inputs_series = tf.split(batchX_placeholder, truncated_backprop_length,1)
labels_series = tf.unstack(batchY_placeholder, axis=1)
In [23]:
# Forward passes
cell = tf.nn.rnn_cell.BasicLSTMCell(state_size, state_is_tuple=True)
states_series, current_state = tf.nn.static_rnn(cell, inputs_series, init_state)

logits_series = [tf.matmul(state, W2) + b2 for state in states_series] #Broadcasted addition
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) for logits, labels in zip(logits_series,labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)
In [24]:
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()
        _current_cell_state = np.zeros((batch_size, state_size))
        _current_hidden_state = np.zeros((batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    cell_state: _current_cell_state,
                    hidden_state: _current_hidden_state

                })

            _current_cell_state, _current_hidden_state = _current_state

            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Batch loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()
<Figure size 432x288 with 0 Axes>
New data, epoch 0
Step 0 Batch loss 0.69308263
Step 100 Batch loss 0.5740948
Step 200 Batch loss 0.5479877
Step 300 Batch loss 0.09256464
Step 400 Batch loss 0.023889415
Step 500 Batch loss 0.013113284
Step 600 Batch loss 0.007990479
New data, epoch 1
Step 0 Batch loss 0.47455165
Step 100 Batch loss 0.005530475
Step 200 Batch loss 0.004880399
Step 300 Batch loss 0.003307631
Step 400 Batch loss 0.0030651344
Step 500 Batch loss 0.002255631
Step 600 Batch loss 0.0020037685
New data, epoch 2
Step 0 Batch loss 0.900985
Step 100 Batch loss 0.0017243241
Step 200 Batch loss 0.0014562543
Step 300 Batch loss 0.0015029077
Step 400 Batch loss 0.0014899417
Step 500 Batch loss 0.0010047215
Step 600 Batch loss 0.0011756229
New data, epoch 3
Step 0 Batch loss 0.70176333
Step 100 Batch loss 0.0013060072
Step 200 Batch loss 0.00081299344
Step 300 Batch loss 0.000978227
Step 400 Batch loss 0.0011675187
Step 500 Batch loss 0.0007133144
Step 600 Batch loss 0.0007995808
New data, epoch 4
Step 0 Batch loss 0.5927401
Step 100 Batch loss 0.00066489616
Step 200 Batch loss 0.00058677635
Step 300 Batch loss 0.0006112889
Step 400 Batch loss 0.0009566307
Step 500 Batch loss 0.0005825069
Step 600 Batch loss 0.0004714114
New data, epoch 5
Step 0 Batch loss 0.5807799
Step 100 Batch loss 0.00074211153
Step 200 Batch loss 0.0005493958
Step 300 Batch loss 0.00051210023
Step 400 Batch loss 0.00052385713
Step 500 Batch loss 0.00047664504
Step 600 Batch loss 0.000458911
New data, epoch 6
Step 0 Batch loss 0.43799642
Step 100 Batch loss 0.00057161227
Step 200 Batch loss 0.00044315468
Step 300 Batch loss 0.000557181
Step 400 Batch loss 0.00046680553
Step 500 Batch loss 0.00041200794
Step 600 Batch loss 0.00037649006
New data, epoch 7
Step 0 Batch loss 0.839055
Step 100 Batch loss 0.00029412046
Step 200 Batch loss 0.00031541998
Step 300 Batch loss 0.00049506663
Step 400 Batch loss 0.00031525813
Step 500 Batch loss 0.00037317592
Step 600 Batch loss 0.0003230826
New data, epoch 8
Step 0 Batch loss 0.5460883
Step 100 Batch loss 0.00031723242
Step 200 Batch loss 0.00043820654
Step 300 Batch loss 0.0002893388
Step 400 Batch loss 0.00028647718
Step 500 Batch loss 0.00030241735
Step 600 Batch loss 0.00038411538
New data, epoch 9
Step 0 Batch loss 0.70026326
Step 100 Batch loss 0.00028915523
Step 200 Batch loss 0.00030398692
Step 300 Batch loss 0.00028282497
Step 400 Batch loss 0.00030019667
Step 500 Batch loss 0.00034289065
Step 600 Batch loss 0.00024168826
New data, epoch 10
Step 0 Batch loss 1.0750598
Step 100 Batch loss 0.00035293392
Step 200 Batch loss 0.0003921036
Step 300 Batch loss 0.00047222842
Step 400 Batch loss 0.00044746776
Step 500 Batch loss 0.00046427405
Step 600 Batch loss 0.00033717812
New data, epoch 11
Step 0 Batch loss 0.27755404
Step 100 Batch loss 0.00033841887
Step 200 Batch loss 0.00037957338
Step 300 Batch loss 0.00037384525
Step 400 Batch loss 0.00032976523
Step 500 Batch loss 0.0002584302
Step 600 Batch loss 0.0003192407
New data, epoch 12
Step 0 Batch loss 0.45708573
Step 100 Batch loss 0.00023744849
Step 200 Batch loss 0.0003030932
Step 300 Batch loss 0.0002987323
Step 400 Batch loss 0.00024694984
Step 500 Batch loss 0.00022848482
Step 600 Batch loss 0.00030943842
New data, epoch 13
Step 0 Batch loss 1.1079967
Step 100 Batch loss 0.00023018268
Step 200 Batch loss 0.00031940022
Step 300 Batch loss 0.0003810525
Step 400 Batch loss 0.00023887785
Step 500 Batch loss 0.00029132614
Step 600 Batch loss 0.0002699033
New data, epoch 14
Step 0 Batch loss 0.5251927
Step 100 Batch loss 0.00022314918
Step 200 Batch loss 0.00025014437
Step 300 Batch loss 0.00021712527
Step 400 Batch loss 0.00025150322
Step 500 Batch loss 0.000255053
Step 600 Batch loss 0.00028761744
New data, epoch 15
Step 0 Batch loss 0.85371715
Step 100 Batch loss 0.0002231724
Step 200 Batch loss 0.00036116407
Step 300 Batch loss 0.00022946882
Step 400 Batch loss 0.00023825395
Step 500 Batch loss 0.00017456748
Step 600 Batch loss 0.00019002012
New data, epoch 16
Step 0 Batch loss 1.2364419
Step 100 Batch loss 0.00022184683
Step 200 Batch loss 0.00022059912
Step 300 Batch loss 0.00022933568
Step 400 Batch loss 0.000319852
Step 500 Batch loss 0.00032730534
Step 600 Batch loss 0.00019379245
New data, epoch 17
Step 0 Batch loss 1.0195975
Step 100 Batch loss 0.00022434742
Step 200 Batch loss 0.00023454311
Step 300 Batch loss 0.00021691006
Step 400 Batch loss 0.00023509815
Step 500 Batch loss 0.0002457771
Step 600 Batch loss 0.00017381085
New data, epoch 18
Step 0 Batch loss 0.5653013
Step 100 Batch loss 0.00020220432
Step 200 Batch loss 0.00023946866
Step 300 Batch loss 0.0002053471
Step 400 Batch loss 0.00016366476
Step 500 Batch loss 0.00022538993
Step 600 Batch loss 0.00018603388
New data, epoch 19
Step 0 Batch loss 0.31418866
Step 100 Batch loss 0.00031793606
Step 200 Batch loss 0.00021100588
Step 300 Batch loss 0.0002081261
Step 400 Batch loss 0.00015878538
Step 500 Batch loss 0.00018147459
Step 600 Batch loss 0.00018880297

MultiLayered LSTM

_current_state = np.zeros((num_layers, 2, batch_size, state_size)) #2 = cells and hidden state

_total_loss, _train_step, _current_state, _predictions_series = sess.run( [total_loss, train_step, current_state, predictions_series], feed_dict={ batchX_placeholder: batchX, batchY_placeholder: batchY, init_state: _current_state })

init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])

state_per_layer_list = tf.unpack(init_state, axis=0) rnn_tuple_state = tuple( [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1]) for idx in range(num_layers)] )

# Forward passes cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True) states_series, current_state = tf.nn.rnn(cell, inputs_series, initial_state=rnn_tuple_state)

In [7]:
num_epochs = 100
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length
num_layers = 3
In [8]:
init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
In [9]:
state_per_layer_list = tf.unstack(init_state, axis=0)
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
     for idx in range(num_layers)]
)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)
In [10]:
# Unpack columns
inputs_series = tf.split(batchX_placeholder, truncated_backprop_length, 1)
labels_series = tf.unstack(batchY_placeholder, axis=1)
tf.nn.static_rnn(cell, inputs_series, initial_state=rnn_tuple_state)
In [13]:
# Forward passes
cells = []
for l in range(num_layers):
    cells.append(tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True, reuse=True))
cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
states_series, current_state = tf.nn.static_rnn(cell, inputs_series, initial_state=rnn_tuple_state)

logits_series = [tf.matmul(state, W2) + b2 for state in states_series] #Broadcasted addition
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) for logits, labels in zip(logits_series,labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)
In [16]:
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()

        _current_state = np.zeros((num_layers, 2, batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    init_state: _current_state
                })


            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Batch loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()
<Figure size 432x288 with 0 Axes>
New data, epoch 0
Step 0 Batch loss 0.6932568
Step 100 Batch loss 0.68839025
Step 200 Batch loss 0.59013796
Step 300 Batch loss 0.48798788
Step 400 Batch loss 0.48281977
Step 500 Batch loss 0.06706321
Step 600 Batch loss 0.013845464
New data, epoch 1
Step 0 Batch loss 0.47437552
Step 100 Batch loss 0.00432524
Step 200 Batch loss 0.002864269
Step 300 Batch loss 0.0021214883
Step 400 Batch loss 0.001675802
Step 500 Batch loss 0.0015142768
Step 600 Batch loss 0.0012105916
New data, epoch 2
Step 0 Batch loss 0.70432633
Step 100 Batch loss 0.0029130569
Step 200 Batch loss 0.0020726628
Step 300 Batch loss 0.0015323089
Step 400 Batch loss 0.0015713219
Step 500 Batch loss 0.001054751
Step 600 Batch loss 0.0010608693
New data, epoch 3
Step 0 Batch loss 0.99548006
Step 100 Batch loss 0.0010272858
Step 200 Batch loss 0.00090478867
Step 300 Batch loss 0.0008401722
Step 400 Batch loss 0.00079807005
Step 500 Batch loss 0.00065314595
Step 600 Batch loss 0.000578491
New data, epoch 4
Step 0 Batch loss 0.58743143
Step 100 Batch loss 0.0006878531
Step 200 Batch loss 0.0006120762
Step 300 Batch loss 0.0005602129
Step 400 Batch loss 0.0005100224
Step 500 Batch loss 0.00042156264
Step 600 Batch loss 0.0004379494
New data, epoch 5
Step 0 Batch loss 0.7926665
Step 100 Batch loss 0.0005900583
Step 200 Batch loss 0.0004361567
Step 300 Batch loss 0.0004281046
Step 400 Batch loss 0.00038773345
Step 500 Batch loss 0.00042448918
Step 600 Batch loss 0.0003825442
New data, epoch 6
Step 0 Batch loss 0.5242294
Step 100 Batch loss 0.00038974723
Step 200 Batch loss 0.00037480364
Step 300 Batch loss 0.0003642648
Step 400 Batch loss 0.00033455808
Step 500 Batch loss 0.0002984124
Step 600 Batch loss 0.0002803059
New data, epoch 7
Step 0 Batch loss 0.325301
Step 100 Batch loss 0.00054258626
Step 200 Batch loss 0.00039540292
Step 300 Batch loss 0.00037772616
Step 400 Batch loss 0.00035025715
Step 500 Batch loss 0.00033866183
Step 600 Batch loss 0.00031203398
New data, epoch 8
Step 0 Batch loss 0.44868982
Step 100 Batch loss 0.00033042583
Step 200 Batch loss 0.00032033553
Step 300 Batch loss 0.0002754921
Step 400 Batch loss 0.00027170617
Step 500 Batch loss 0.00026101424
Step 600 Batch loss 0.0002491086
New data, epoch 9
Step 0 Batch loss 0.6554325
Step 100 Batch loss 0.00032347275
Step 200 Batch loss 0.00028352722
Step 300 Batch loss 0.00029370037
Step 400 Batch loss 0.0002534043
Step 500 Batch loss 0.00025913646
Step 600 Batch loss 0.00023474351
New data, epoch 10
Step 0 Batch loss 0.6652185
Step 100 Batch loss 0.00030487293
Step 200 Batch loss 0.00026815187
Step 300 Batch loss 0.00026030364
Step 400 Batch loss 0.00021314705
Step 500 Batch loss 0.00021548073
Step 600 Batch loss 0.00021410892
New data, epoch 11
Step 0 Batch loss 0.51591164
Step 100 Batch loss 0.00025008962
Step 200 Batch loss 0.0002471315
Step 300 Batch loss 0.00021525307
Step 400 Batch loss 0.00019469531
Step 500 Batch loss 0.00018908482
Step 600 Batch loss 0.00019842475
New data, epoch 12
Step 0 Batch loss 0.6568196
Step 100 Batch loss 0.00025814914
Step 200 Batch loss 0.00020639245
Step 300 Batch loss 0.00026313929
Step 400 Batch loss 0.00021558357
Step 500 Batch loss 0.00020683452
Step 600 Batch loss 0.00018648386
New data, epoch 13
Step 0 Batch loss 0.6128642
Step 100 Batch loss 0.00020136937
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-16-59557012e2c4> in <module>
     33             if batch_idx%100 == 0:
     34                 print("Step",batch_idx, "Batch loss", _total_loss)
---> 35                 plot(loss_list, _predictions_series, batchX, batchY)
     36 
     37 plt.ioff()

<ipython-input-15-a7a221093bf1> in plot(loss_list, predictions_series, batchX, batchY)
     16         plt.bar(left_offset, single_output_series * 0.3, width=1, color="green")
     17 
---> 18     plt.draw()
     19     plt.pause(0.0001)

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in draw()
    681         fig.canvas.draw_idle()
    682     """
--> 683     get_current_fig_manager().canvas.draw_idle()
    684 
    685 

~/anaconda3/lib/python3.6/site-packages/matplotlib/backend_bases.py in draw_idle(self, *args, **kwargs)
   1897         if not self._is_idle_drawing:
   1898             with self._idle_draw_cntx():
-> 1899                 self.draw(*args, **kwargs)
   1900 
   1901     def draw_cursor(self, event):

~/anaconda3/lib/python3.6/site-packages/matplotlib/backends/backend_agg.py in draw(self)
    400         toolbar = self.toolbar
    401         try:
--> 402             self.figure.draw(self.renderer)
    403             # A GUI class may be need to update a window using this draw, so
    404             # don't forget to call the superclass.

~/anaconda3/lib/python3.6/site-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     48                 renderer.start_filter()
     49 
---> 50             return draw(artist, renderer, *args, **kwargs)
     51         finally:
     52             if artist.get_agg_filter() is not None:

~/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py in draw(self, renderer)
   1647 
   1648             mimage._draw_list_compositing_images(
-> 1649                 renderer, self, artists, self.suppressComposite)
   1650 
   1651             renderer.close_group('figure')

~/anaconda3/lib/python3.6/site-packages/matplotlib/image.py in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
    136     if not_composite or not has_images:
    137         for a in artists:
--> 138             a.draw(renderer)
    139     else:
    140         # Composite any adjacent images together

~/anaconda3/lib/python3.6/site-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     48                 renderer.start_filter()
     49 
---> 50             return draw(artist, renderer, *args, **kwargs)
     51         finally:
     52             if artist.get_agg_filter() is not None:

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in draw(self, renderer, inframe)
   2626             renderer.stop_rasterizing()
   2627 
-> 2628         mimage._draw_list_compositing_images(renderer, self, artists)
   2629 
   2630         renderer.close_group('axes')

~/anaconda3/lib/python3.6/site-packages/matplotlib/image.py in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
    136     if not_composite or not has_images:
    137         for a in artists:
--> 138             a.draw(renderer)
    139     else:
    140         # Composite any adjacent images together

~/anaconda3/lib/python3.6/site-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     48                 renderer.start_filter()
     49 
---> 50             return draw(artist, renderer, *args, **kwargs)
     51         finally:
     52             if artist.get_agg_filter() is not None:

~/anaconda3/lib/python3.6/site-packages/matplotlib/patches.py in draw(self, renderer)
    533             renderer = PathEffectRenderer(self.get_path_effects(), renderer)
    534 
--> 535         renderer.draw_path(gc, tpath, affine, rgbFace)
    536 
    537         gc.restore()

~/anaconda3/lib/python3.6/site-packages/matplotlib/backends/backend_agg.py in draw_path(self, gc, path, transform, rgbFace)
    147         else:
    148             try:
--> 149                 self._renderer.draw_path(gc, path, transform, rgbFace)
    150             except OverflowError:
    151                 raise OverflowError("Exceeded cell block limit (set "

~/anaconda3/lib/python3.6/site-packages/matplotlib/transforms.py in __array__(self, *args, **kwargs)
   1727         self._inverted = None
   1728 
-> 1729     def __array__(self, *args, **kwargs):
   1730         # optimises the access of the transform matrix vs the superclass
   1731         return self.get_matrix()

KeyboardInterrupt: 

logits = tf.matmul(states_series, W2) + b2 #Broadcasted addition labels = tf.reshape(batchY_placeholder, [-1])

logits_series = tf.unpack(tf.reshape(logits, [batch_size, truncated_backprop_length, num_classes]), axis=1) predictions_series = [tf.nn.softmax(logit) for logit in logits_list]

losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)

In [6]:
num_epochs = 100
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length
num_layers = 3
In [7]:
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
In [11]:
state_per_layer_list = tf.unstack(init_state, axis=0)
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
     for idx in range(num_layers)]
)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)
In [13]:
cells = []
for l in range(num_layers):
    cells.append(tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True))
cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
states_series, current_state = tf.nn.dynamic_rnn(cell, tf.expand_dims(batchX_placeholder, -1), initial_state=rnn_tuple_state)
states_series = tf.reshape(states_series, [-1, state_size])
In [14]:
logits = tf.matmul(states_series, W2) + b2 #Broadcasted addition
labels = tf.reshape(batchY_placeholder, [-1])

logits_series = tf.unstack(tf.reshape(logits, [batch_size, truncated_backprop_length, num_classes]), axis=1)
predictions_series = [tf.nn.softmax(logit) for logit in logits_series]

losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)
In [15]:
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()

        _current_state = np.zeros((num_layers, 2, batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    init_state: _current_state
                })


            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Batch loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()W = tf.Variable(np.random.rand(state_size+1, state_size), dtype=tf.float32)
b = tf.Variable(np.zeros((1,state_size)), dtype=tf.float32)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)
WARNING:tensorflow:From /root/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py:189: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
<Figure size 432x288 with 0 Axes>
New data, epoch 0
Step 0 Batch loss 0.69462144
Step 100 Batch loss 0.6403059
Step 200 Batch loss 0.5330861
Step 300 Batch loss 0.4560477
Step 400 Batch loss 0.24546789
Step 500 Batch loss 0.011403675
Step 600 Batch loss 0.004564733
New data, epoch 1
Step 0 Batch loss 0.59189963
Step 100 Batch loss 0.0022266419
Step 200 Batch loss 0.0015301491
Step 300 Batch loss 0.0013130861
Step 400 Batch loss 0.001014766
Step 500 Batch loss 0.0007172469
Step 600 Batch loss 0.00066570926
New data, epoch 2
Step 0 Batch loss 0.3979042
Step 100 Batch loss 0.00081822515
Step 200 Batch loss 0.000649562
Step 300 Batch loss 0.000654284
Step 400 Batch loss 0.0005989049
Step 500 Batch loss 0.00042402736
Step 600 Batch loss 0.000498191
New data, epoch 3
Step 0 Batch loss 0.28890017
Step 100 Batch loss 0.00082403404
Step 200 Batch loss 0.0004770147
Step 300 Batch loss 0.0004675912
Step 400 Batch loss 0.00046726296
Step 500 Batch loss 0.00045479182
Step 600 Batch loss 0.000337058
New data, epoch 4
Step 0 Batch loss 0.4139736
Step 100 Batch loss 0.00034858542
Step 200 Batch loss 0.00029061848
Step 300 Batch loss 0.00025067956
Step 400 Batch loss 0.00024369471
Step 500 Batch loss 0.0002506139
Step 600 Batch loss 0.00024406282
New data, epoch 5
Step 0 Batch loss 0.39618513
Step 100 Batch loss 0.00024347995
Step 200 Batch loss 0.00025281313
Step 300 Batch loss 0.00020539566
Step 400 Batch loss 0.00021690798
Step 500 Batch loss 0.00020348333
Step 600 Batch loss 0.00018319
New data, epoch 6
Step 0 Batch loss 0.66694725
Step 100 Batch loss 0.00020019071
Step 200 Batch loss 0.00021230054
Step 300 Batch loss 0.00017080053
Step 400 Batch loss 0.00020071829
Step 500 Batch loss 0.00018426323
Step 600 Batch loss 0.0001786654
New data, epoch 7
Step 0 Batch loss 0.70549023
Step 100 Batch loss 0.00019922701
Step 200 Batch loss 0.00017593557
Step 300 Batch loss 0.00016970292
Step 400 Batch loss 0.00017887873
Step 500 Batch loss 0.0001442471
Step 600 Batch loss 0.00015375964
New data, epoch 8
Step 0 Batch loss 0.47492212
Step 100 Batch loss 0.00017672102
Step 200 Batch loss 0.00014205706
Step 300 Batch loss 0.00014955053
Step 400 Batch loss 0.00014298932
Step 500 Batch loss 0.00014204219
Step 600 Batch loss 0.00013350036
New data, epoch 9
Step 0 Batch loss 0.7015376
Step 100 Batch loss 0.00020066673
Step 200 Batch loss 0.00015156831
Step 300 Batch loss 0.00013563372
Step 400 Batch loss 0.00013222758
Step 500 Batch loss 0.00015432577
Step 600 Batch loss 0.00014407876
New data, epoch 10
Step 0 Batch loss 0.57007074
Step 100 Batch loss 0.00020309168
Step 200 Batch loss 0.00013806956
Step 300 Batch loss 0.00015679267
Step 400 Batch loss 0.00014979264
Step 500 Batch loss 0.00013528438
Step 600 Batch loss 0.00012954207
New data, epoch 11
Step 0 Batch loss 0.77198404
Step 100 Batch loss 0.00019985987
Step 200 Batch loss 0.00018815658
Step 300 Batch loss 0.00015372517
Step 400 Batch loss 0.00015308762
Step 500 Batch loss 0.00012647577
Step 600 Batch loss 0.00015316278
New data, epoch 12
Step 0 Batch loss 0.44704905
Step 100 Batch loss 0.00019912994
Step 200 Batch loss 0.00017792913
Step 300 Batch loss 0.00016118737
Step 400 Batch loss 0.00013914079
Step 500 Batch loss 0.00014415495
Step 600 Batch loss 9.900126e-05
New data, epoch 13
Step 0 Batch loss 0.41073528
Step 100 Batch loss 0.0009897884
Step 200 Batch loss 0.00042810012
Step 300 Batch loss 0.00041239464
Step 400 Batch loss 0.00035769984
Step 500 Batch loss 0.0002505757
Step 600 Batch loss 0.0002042584
New data, epoch 14
Step 0 Batch loss 0.3625364
Step 100 Batch loss 0.0002875854
Step 200 Batch loss 0.00023715876
Step 300 Batch loss 0.00025518273
Step 400 Batch loss 0.00022542775
Step 500 Batch loss 0.00015973883
Step 600 Batch loss 0.00017446843
New data, epoch 15
Step 0 Batch loss 0.51592
Step 100 Batch loss 0.00022438513
Step 200 Batch loss 0.00019000995
Step 300 Batch loss 0.00017601653
Step 400 Batch loss 0.00014197198
Step 500 Batch loss 0.0001588454
Step 600 Batch loss 0.00016186092
New data, epoch 16
Step 0 Batch loss 0.33341405
Step 100 Batch loss 0.00018796202
Step 200 Batch loss 0.0001475322
Step 300 Batch loss 0.00013919313
Step 400 Batch loss 0.000121400655
Step 500 Batch loss 0.00012538415
Step 600 Batch loss 0.000126348
New data, epoch 17
Step 0 Batch loss 0.27615955
Step 100 Batch loss 0.00011413091
Step 200 Batch loss 0.00012616026
Step 300 Batch loss 0.0001170844
Step 400 Batch loss 0.00011075015
Step 500 Batch loss 9.677185e-05
Step 600 Batch loss 0.00010936171
New data, epoch 18
Step 0 Batch loss 0.27728036
Step 100 Batch loss 0.000120370074
Step 200 Batch loss 0.00012167257
Step 300 Batch loss 0.000115299474
Step 400 Batch loss 0.00010562373
Step 500 Batch loss 0.00010663352
Step 600 Batch loss 8.45541e-05
New data, epoch 19
Step 0 Batch loss 0.16948073
Step 100 Batch loss 8.913559e-05
Step 200 Batch loss 9.106041e-05
Step 300 Batch loss 8.5186235e-05
Step 400 Batch loss 8.568552e-05
Step 500 Batch loss 8.628632e-05
Step 600 Batch loss 8.9613175e-05
New data, epoch 20
Step 0 Batch loss 0.46787563
Step 100 Batch loss 0.0001193131
Step 200 Batch loss 9.625699e-05
Step 300 Batch loss 0.00010274808
Step 400 Batch loss 7.903248e-05
Step 500 Batch loss 8.2644736e-05
Step 600 Batch loss 7.527247e-05
New data, epoch 21
Step 0 Batch loss 0.46513498
Step 100 Batch loss 0.00010563305
Step 200 Batch loss 9.7159274e-05
Step 300 Batch loss 9.643067e-05
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-15-59557012e2c4> in <module>
     33             if batch_idx%100 == 0:
     34                 print("Step",batch_idx, "Batch loss", _total_loss)
---> 35                 plot(loss_list, _predictions_series, batchX, batchY)
     36 
     37 plt.ioff()

<ipython-input-5-a7a221093bf1> in plot(loss_list, predictions_series, batchX, batchY)
     17 
     18     plt.draw()
---> 19     plt.pause(0.0001)

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in pause(interval)
    292         if canvas.figure.stale:
    293             canvas.draw_idle()
--> 294         show(block=False)
    295         canvas.start_event_loop(interval)
    296     else:

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in show(*args, **kw)
    252     """
    253     global _show
--> 254     return _show(*args, **kw)
    255 
    256 

~/anaconda3/lib/python3.6/site-packages/ipykernel/pylab/backend_inline.py in show(close, block)
     37             display(
     38                 figure_manager.canvas.figure,
---> 39                 metadata=_fetch_figure_metadata(figure_manager.canvas.figure)
     40             )
     41     finally:

~/anaconda3/lib/python3.6/site-packages/IPython/core/display.py in display(include, exclude, metadata, transient, display_id, *objs, **kwargs)
    302             publish_display_data(data=obj, metadata=metadata, **kwargs)
    303         else:
--> 304             format_dict, md_dict = format(obj, include=include, exclude=exclude)
    305             if not format_dict:
    306                 # nothing to display (e.g. _ipython_display_ took over)

~/anaconda3/lib/python3.6/site-packages/IPython/core/formatters.py in format(self, obj, include, exclude)
    178             md = None
    179             try:
--> 180                 data = formatter(obj)
    181             except:
    182                 # FIXME: log the exception

<decorator-gen-9> in __call__(self, obj)

~/anaconda3/lib/python3.6/site-packages/IPython/core/formatters.py in catch_format_error(method, self, *args, **kwargs)
    222     """show traceback on failed format call"""
    223     try:
--> 224         r = method(self, *args, **kwargs)
    225     except NotImplementedError:
    226         # don't warn on NotImplementedErrors

~/anaconda3/lib/python3.6/site-packages/IPython/core/formatters.py in __call__(self, obj)
    339                 pass
    340             else:
--> 341                 return printer(obj)
    342             # Finally look for special method names
    343             method = get_real_method(obj, self.print_method)

~/anaconda3/lib/python3.6/site-packages/IPython/core/pylabtools.py in <lambda>(fig)
    242 
    243     if 'png' in formats:
--> 244         png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png', **kwargs))
    245     if 'retina' in formats or 'png2x' in formats:
    246         png_formatter.for_type(Figure, lambda fig: retina_figure(fig, **kwargs))

~/anaconda3/lib/python3.6/site-packages/IPython/core/pylabtools.py in print_figure(fig, fmt, bbox_inches, **kwargs)
    126 
    127     bytes_io = BytesIO()
--> 128     fig.canvas.print_figure(bytes_io, **kw)
    129     data = bytes_io.getvalue()
    130     if fmt == 'svg':

~/anaconda3/lib/python3.6/site-packages/matplotlib/backend_bases.py in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, **kwargs)
   2051                     bbox_artists = kwargs.pop("bbox_extra_artists", None)
   2052                     bbox_inches = self.figure.get_tightbbox(renderer,
-> 2053                             bbox_extra_artists=bbox_artists)
   2054                     pad = kwargs.pop("pad_inches", None)
   2055                     if pad is None:

~/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py in get_tightbbox(self, renderer, bbox_extra_artists)
   2278                 try:
   2279                     bbox = ax.get_tightbbox(renderer,
-> 2280                             bbox_extra_artists=bbox_extra_artists)
   2281                 except TypeError:
   2282                     bbox = ax.get_tightbbox(renderer)

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in get_tightbbox(self, renderer, call_axes_locator, bbox_extra_artists)
   4394 
   4395         for a in bbox_artists:
-> 4396             bbox = a.get_tightbbox(renderer)
   4397             if (bbox is not None and
   4398                     (bbox.width != 0 or bbox.height != 0) and

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in get_tightbbox(self, renderer)
   1153         for a in [self.label, self.offsetText]:
   1154             bbox = a.get_window_extent(renderer)
-> 1155             if (np.isfinite(bbox.width) and np.isfinite(bbox.height) and
   1156                     a.get_visible()):
   1157                 bb.append(bbox)

~/anaconda3/lib/python3.6/site-packages/matplotlib/transforms.py in width(self)
    400         """
    401         points = self.get_points()
--> 402         return points[1, 0] - points[0, 0]
    403 
    404     @property

KeyboardInterrupt: 

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.5) cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

In [16]:
num_epochs = 10
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length
num_layers = 3
In [6]:
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
In [13]:
state_per_layer_list = tf.unstack(init_state, axis=0)
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
     for idx in range(num_layers)]
)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)
In [8]:
W = tf.Variable(np.random.rand(state_size+1, state_size), dtype=tf.float32)
b = tf.Variable(np.zeros((1,state_size)), dtype=tf.float32)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)
In [15]:
# Forward passes
#cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True, reuse=True)
#cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.5)
cells = []
for l in range(num_layers):
    cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=Truewith tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()

        _current_state = np.zeros((num_layers, 2, batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    init_state: _current_state
                })


            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Batch loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show())
    cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.5)
    cells.append(cell)
cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
states_series, current_state = tf.nn.dynamic_rnn(cell, tf.expand_dims(batchX_placeholder, -1),
                                                 initial_state=rnn_tuple_state)
states_series = tf.reshape(states_series, [-1, state_size])

logits = tf.matmul(states_series, W2) + b2 #Broadcasted addition
labels = tf.reshape(batchY_placeholder, [-1])

logits_series = tf.unstack(tf.reshape(logits, [batch_size, truncated_backprop_length, 2]), axis=1)
predictions_series = [tf.nn.softmax(logit) for logit in logits_series]

losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)
In [17]:
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()

        _current_state = np.zeros((num_layers, 2, batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    init_state: _current_state
                })


            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Batch loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()
WARNING:tensorflow:From /root/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py:189: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
<Figure size 432x288 with 0 Axes>
New data, epoch 0
Step 0 Batch loss 0.6929688
Step 100 Batch loss 0.6954956
Step 200 Batch loss 0.7002546
Step 300 Batch loss 0.57871366
Step 400 Batch loss 0.5939695
Step 500 Batch loss 0.60475427
Step 600 Batch loss 0.5576292
New data, epoch 1
Step 0 Batch loss 0.4871638
Step 100 Batch loss 0.41920424
Step 200 Batch loss 0.31294614
Step 300 Batch loss 0.42661774
Step 400 Batch loss 0.20004034
Step 500 Batch loss 0.25433797
Step 600 Batch loss 0.2032494
New data, epoch 2
Step 0 Batch loss 0.43561286
Step 100 Batch loss 0.16278562
Step 200 Batch loss 0.19558659
Step 300 Batch loss 0.30000505
Step 400 Batch loss 0.16154401
Step 500 Batch loss 0.1886921
Step 600 Batch loss 0.109658815
New data, epoch 3
Step 0 Batch loss 0.6052768
Step 100 Batch loss 0.13835937
Step 200 Batch loss 0.17098038
Step 300 Batch loss 0.13769978
Step 400 Batch loss 0.10130726
Step 500 Batch loss 0.1967591
Step 600 Batch loss 0.2203034
New data, epoch 4
Step 0 Batch loss 0.46089447
Step 100 Batch loss 0.097451195
Step 200 Batch loss 0.41533235
Step 300 Batch loss 0.15700059
Step 400 Batch loss 0.21913075
Step 500 Batch loss 0.19718121
Step 600 Batch loss 0.087576464
New data, epoch 5
Step 0 Batch loss 0.34818193
Step 100 Batch loss 0.10820887
Step 200 Batch loss 0.15518722
Step 300 Batch loss 0.15856141
Step 400 Batch loss 0.1631305
Step 500 Batch loss 0.07454713
Step 600 Batch loss 0.122998826
New data, epoch 6
Step 0 Batch loss 0.38226745
Step 100 Batch loss 0.22846888
Step 200 Batch loss 0.09558049
Step 300 Batch loss 0.13415973
Step 400 Batch loss 0.1253299
Step 500 Batch loss 0.13484524
Step 600 Batch loss 0.12309661
New data, epoch 7
Step 0 Batch loss 0.49858776
Step 100 Batch loss 0.17712341
Step 200 Batch loss 0.08960713
Step 300 Batch loss 0.20791616
Step 400 Batch loss 0.14790417
Step 500 Batch loss 0.13240221
Step 600 Batch loss 0.21758533
New data, epoch 8
Step 0 Batch loss 0.45786998
Step 100 Batch loss 0.17983684
Step 200 Batch loss 0.11675018
Step 300 Batch loss 0.15271305
Step 400 Batch loss 0.21675894
Step 500 Batch loss 0.14125304
Step 600 Batch loss 0.081797935
New data, epoch 9
Step 0 Batch loss 0.42411888
Step 100 Batch loss 0.08109849
Step 200 Batch loss 0.13136278
Step 300 Batch loss 0.15327615
Step 400 Batch loss 0.13441864
Step 500 Batch loss 0.17704493
Step 600 Batch loss 0.28483298